# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import update_inputs

# Seed the global NumPy RNG so the random inputs below -- and therefore the
# hard-coded reference outputs further down -- are reproducible.
np.random.seed(2023)

bs = 48  # number of slots whose sequence lengths get populated
max_bs = 64  # length of the per-slot arrays; slots bs..max_bs-1 keep zero lengths
max_input_length = 6144

# NOTE: the order of the np.random calls in this section is significant;
# reordering them changes every random array and invalidates the references.
stop_flags = np.random.randint(0, 2, max_bs).astype("bool")
not_need_stop = np.array([1], "bool")

# Even slot i: seq_lens_encoder[i] = seq_lens_this_time[i] = i.
# Odd slot i:  seq_lens_decoder[i] = i and seq_lens_this_time[i] = 1.
seq_lens_this_time = np.zeros([bs], "int32")
seq_lens_encoder = np.zeros([max_bs], "int32")
seq_lens_decoder = np.zeros([max_bs], "int32")
seq_lens_encoder[0:bs:2] = np.arange(0, bs, 2)
seq_lens_this_time[0:bs:2] = np.arange(0, bs, 2)
seq_lens_decoder[1:bs:2] = np.arange(1, bs, 2)
seq_lens_this_time[1:bs:2] = 1

input_ids_np = np.random.randint(1, 10, [max_bs, max_input_length], "int64")
stop_nums = np.array([max_bs], "int64")
next_tokens = np.random.randint(1, 10, [max_bs], "int64")
is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool")

# Convert the host arrays to Paddle tensors. not_need_stop is pinned to the
# CPU; every other tensor is created on Paddle's default place.
(
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    input_ids,
    stop_nums,
    next_tokens,
    is_block_step,
) = [
    paddle.to_tensor(host_array)
    for host_array in (
        stop_flags,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        input_ids_np,
        stop_nums,
        next_tokens,
        is_block_step,
    )
]
not_need_stop = paddle.to_tensor(not_need_stop, place=paddle.CPUPlace())


def _dump(captioned_values):
    """Print every ``(caption, value)`` pair via ``print(caption, value)``."""
    for caption, value in captioned_values:
        print(caption, value)


# State before the op.
_dump(
    [
        ("stop_flags:\n", stop_flags),
        ("not_need_stop:\n", not_need_stop),
        ("seq_lens_this_time:\n", seq_lens_this_time),
        ("seq_lens_encoder:\n", seq_lens_encoder),
        ("seq_lens_decoder:\n", seq_lens_decoder),
        ("input_ids:\n", input_ids),
        ("stop_nums:\n", stop_nums),
        ("next_tokens:\n", next_tokens),
        ("is_block_step:\n", is_block_step),
    ]
)

# The return value of update_inputs is not used: the results are read back
# from these same tensor objects afterwards.
update_inputs(
    stop_flags,
    not_need_stop,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    input_ids,
    stop_nums,
    next_tokens,
    is_block_step,
)

# State after the op.
print("-" * 50)
_dump(
    [
        ("stop_flags:\n", stop_flags),
        ("not_need_stop:\n", not_need_stop),
        ("seq_lens_this_time:\n", seq_lens_this_time),
        ("seq_lens_encoder:\n", seq_lens_encoder),
        ("seq_lens_decoder:\n", seq_lens_decoder),
        ("input_ids:\n", input_ids),
        ("stop_nums:\n", stop_nums),
        ("next_tokens:\n", next_tokens),
    ]
)

# Expected state after update_inputs for the seeded inputs above.
ref_not_need_stop_out = np.array([True])

# fmt: off
# One entry per populated slot (bs == 48), eight slots per row.
ref_seq_lens_this_time_out = np.array(
    [
        0, 0, 1, 0, 0, 1, 0, 1,
        1, 1, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 0, 1,
        1, 0, 1, 1, 0, 0, 0, 1,
        1, 0, 1, 0, 0, 1, 0, 1,
        0, 0, 1, 0, 0, 1, 1, 1,
    ],
    dtype=np.int32,
)

# Every encoder length is expected to be zero across all max_bs slots.
ref_seq_lens_encoder_out = np.zeros(max_bs, dtype=np.int32)

# One entry per slot (max_bs == 64), eight slots per row.
ref_seq_lens_decoder_out = np.array(
    [
        0, 0, 2, 0, 0, 6, 0, 8,
        8, 10, 0, 12, 12, 0, 0, 0,
        0, 0, 0, 0, 20, 22, 0, 24,
        24, 0, 26, 28, 0, 0, 0, 32,
        32, 0, 34, 0, 0, 38, 0, 40,
        0, 0, 42, 0, 0, 46, 46, 48,
        0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0,
    ],
    dtype=np.int32,
)

# The input_ids reference is the original random matrix with only column 0
# replaced by the values below (input_ids_np is updated in place; the device
# tensor was created from a copy earlier, so it is unaffected).
input_ids_np[:, 0] = np.array(
    [
        6, 5, 9, 8, 6, 2, 8, 1,
        3, 1, 3, 6, 9, 8, 1, 9,
        1, 8, 8, 6, 7, 6, 5, 3,
        5, 9, 3, 6, 3, 9, 8, 8,
        8, 8, 4, 8, 7, 4, 2, 3,
        5, 8, 4, 2, 5, 6, 8, 9,
        6, 7, 4, 2, 4, 6, 2, 3,
        4, 9, 7, 2, 1, 8, 7, 8,
    ],
    dtype=np.int64,
)
# fmt: on

# Compare every tensor updated by update_inputs against its reference.
# np.testing.assert_array_equal is used instead of `assert np.all(a == b)`
# because it:
#   * rejects shape mismatches instead of silently broadcasting them away
#     (e.g. a length-1 zero output would otherwise "match" an all-zero
#     64-element reference);
#   * still runs under `python -O`, where bare asserts are stripped;
#   * reports the mismatching elements when it fails.
# It raises AssertionError, which is the same exception type as the old checks.
np.testing.assert_array_equal(
    not_need_stop.numpy(), ref_not_need_stop_out, err_msg="Check not_need_stop failed."
)
np.testing.assert_array_equal(
    seq_lens_this_time.numpy(), ref_seq_lens_this_time_out, err_msg="Check seq_lens_this_time failed."
)
np.testing.assert_array_equal(
    seq_lens_encoder.numpy(), ref_seq_lens_encoder_out, err_msg="Check seq_lens_encoder failed."
)
np.testing.assert_array_equal(
    seq_lens_decoder.numpy(), ref_seq_lens_decoder_out, err_msg="Check seq_lens_decoder failed."
)
np.testing.assert_array_equal(input_ids.numpy(), input_ids_np, err_msg="Check input_ids failed.")
